{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp distributed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.basics import *\n",
    "from fastai.callback.progress import ProgressCallback\n",
    "from torch.nn.parallel import DistributedDataParallel, DataParallel\n",
    "from fastai.data.load import _FakeLoader,_loaders"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distributed and parallel training\n",
    "\n",
    "> Callbacks and helper functions to train in parallel or use distributed training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parallel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Patch the parallel models so they work with RNNs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "def reset(self: DataParallel):\n",
    "    \"Forward `reset` to the wrapped module, if that module defines one.\"\n",
    "    wrapped = self.module\n",
    "    if hasattr(wrapped, 'reset'): wrapped.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ParallelTrainer(Callback):\n",
    "    \"Callback that trains with the model wrapped in `DataParallel`, unwrapping it again after fit.\"\n",
    "    run_after  = TrainEvalCallback\n",
    "    run_before = Recorder\n",
    "\n",
    "    def __init__(self, device_ids):\n",
    "        self.device_ids = device_ids\n",
    "\n",
    "    def before_fit(self):\n",
    "        \"Replace the learner's model with a `DataParallel` wrapper over `device_ids`.\"\n",
    "        self.learn.model = DataParallel(self.learn.model, device_ids=self.device_ids)\n",
    "\n",
    "    def after_fit(self):\n",
    "        \"Restore the module that `DataParallel` was wrapping.\"\n",
    "        self.learn.model = self.learn.model.module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "def to_parallel(self: Learner, device_ids=None):\n",
    "    \"Attach a `ParallelTrainer` for `device_ids` and return the learner so calls can be chained.\"\n",
    "    trainer = ParallelTrainer(device_ids)\n",
    "    self.add_cb(trainer)\n",
    "    return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "def detach_parallel(self: Learner):\n",
    "    \"Drop the `ParallelTrainer` callback and hand back the learner for chaining.\"\n",
    "    cb_type = ParallelTrainer\n",
    "    self.remove_cb(cb_type)\n",
    "    return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "@contextmanager\n",
    "def parallel_ctx(self: Learner, device_ids=None):\n",
    "    \"Context manager: the learner trains in data parallel mode inside the `with` body.\"\n",
    "    try:\n",
    "        # `to_parallel` returns the learner itself, so it can be yielded directly\n",
    "        yield self.to_parallel(device_ids)\n",
    "    finally:\n",
    "        # always detach, whether the body finished normally or raised\n",
    "        self.detach_parallel()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distributed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Patch the parallel models so they work with RNNs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "def reset(self: DistributedDataParallel):\n",
    "    \"Forward `reset` to the wrapped module, if that module defines one.\"\n",
    "    wrapped = self.module\n",
    "    if hasattr(wrapped, 'reset'): wrapped.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convenience functions to set up/tear down torch distributed data parallel mode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def setup_distrib(gpu=None):\n",
    "    \"Select CUDA device `gpu` and, when `num_distrib()` is positive, initialise the NCCL process group from the environment.\"\n",
    "    if gpu is None: return None\n",
    "    device_idx = int(gpu)\n",
    "    torch.cuda.set_device(device_idx)\n",
    "    if num_distrib() > 0:\n",
    "        torch.distributed.init_process_group(backend='nccl', init_method='env://')\n",
    "    return device_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def teardown_distrib():\n",
    "    \"Destroy the torch distributed process group if one is currently initialised.\"\n",
    "    dist = torch.distributed\n",
    "    if dist.is_initialized(): dist.destroy_process_group()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DataLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to change the dataloaders so that they only get one part of the batch each (otherwise there is no point in using distributed training)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class DistributedDL(TfmdDL):\n",
    "    \"Wraps a data loader `dl` so that each rank iterates over its own, equally sized slice of the indices.\"\n",
    "  \n",
    "    # round `number` up to the nearest multiple of `multiple`\n",
    "    _round_to_multiple=lambda number,multiple:int(math.ceil(number/multiple)*multiple)\n",
    "    \n",
    "    def _broadcast(self,t,rank):\n",
    "        \"Broadcasts t from rank `rank` to all other ranks. Returns t so t is same for all ranks after call.\"\n",
    "        t = LongTensor(t).cuda() # nccl only works with cuda tensors\n",
    "        torch.distributed.broadcast(t,rank)\n",
    "        return t.cpu().tolist()\n",
    "    def _to_detach(self,b,cpu=True,gather=True): return to_detach(b,cpu,gather) # member func so we can override for test\n",
    "    def __len__(self):\n",
    "        # per-rank length: the wrapped loader's length rounded up to a multiple of `world_size`, then split evenly\n",
    "        return DistributedDL._round_to_multiple(len(self.dl),self.world_size)//self.world_size\n",
    "    def get_idxs(self):\n",
    "        \"Return this rank's slice of the indices broadcast from rank 0, padded so every rank gets the same count.\"\n",
    "        idxs = self.dl.get_idxs()       # compute get_idxs in all ranks (we'll only use rank 0 but size must be consistent)\n",
    "        idxs = self._broadcast(idxs,0)  # broadcast and receive it from rank 0 to all\n",
    "        self.n = len(idxs)              # we assumed n was dl.n but we really care about number of idxs\n",
    "        # add extra samples to make it evenly divisible\n",
    "        self.n_padded = DistributedDL._round_to_multiple(self.n,self.world_size)\n",
    "        idxs += (idxs * (self.n_padded//self.n))[:self.n_padded-self.n] # idx needs to be repeated when n_padded>>n\n",
    "        # slice padded idxs so that each rank gets self.n_padded//self.world_size tensors\n",
    "        return idxs[self.rank*self.n_padded//self.world_size:(self.rank+1)*self.n_padded//self.world_size]\n",
    "    def before_iter(self): \n",
    "        # `i` counts the items this rank has produced in the current pass; `to_detach` reads it to unpad\n",
    "        self.i = 0\n",
    "        self.dl.before_iter()\n",
    "    def randomize(self): self.dl.randomize()\n",
    "    def after_batch(self,b):\n",
    "        # advance the item counter by the batch size reported by `find_bs`\n",
    "        self.i += find_bs(b)\n",
    "        return self.dl.after_batch(b)\n",
    "    def after_iter(self): \n",
    "        self.dl.after_iter()\n",
    "    def create_batches(self,samps): return self.dl.create_batches(samps)\n",
    "    def to_detach(self,b, cpu=True, gather=True):\n",
    "        \"Detach `b` via `_to_detach`, then, when gathering, trim trailing entries of each tensor to undo the index padding.\"\n",
    "        b = self._to_detach(b, cpu, gather)\n",
    "        def _inner(b):\n",
    "            if b.ndim>0:\n",
    "                # for each rank, compute overflow of read idxs vs self.n and accumulate them to unpad totals after gathering\n",
    "                n = sum([min(0,max(-len(b)//self.world_size,\n",
    "                                   self.n-(self.i+r*self.n_padded//self.world_size))) for r in range(self.world_size)])\n",
    "                # n is <= 0: drop the last -n entries, or keep everything when n == 0\n",
    "                b = b[:n or None]\n",
    "            return b\n",
    "        # only unpad when gathering and once `i` (set in `before_iter`) and `n`/`n_padded` (set in `get_idxs`) exist\n",
    "        return apply(_inner,b) if gather and hasattr(self,'i') and hasattr(self,'n') and hasattr(self,'n_padded') else b\n",
    "    def __init__(self,dl,rank,world_size):\n",
    "        \"Wrap `dl` for process `rank` out of `world_size`, mirroring its batch size, device, `drop_last`, dataset and loader settings.\"\n",
    "        store_attr('dl,rank,world_size')\n",
    "        self.bs,self.device,self.drop_last,self.dataset = dl.bs,dl.device,dl.drop_last,dl.dataset\n",
    "        fake = dl.fake_l\n",
    "        # build a new `_FakeLoader` around this wrapper, reusing the wrapped loader's settings\n",
    "        self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout, persistent_workers=fake.persistent_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "_tmp_file = tempfile.NamedTemporaryFile().name # i tried putting this inside self / _broadcast to no avail\n",
    "# Mocked `_broadcast`, so DistributedDL can be tested without a real DDP setup:\n",
    "# the sending rank writes the tensor to `_tmp_file` and every other rank reads it back.\n",
    "@patch\n",
    "def _broadcast(self:DistributedDL,t,rank):\n",
    "    idxs = LongTensor(t)\n",
    "    if rank != self.rank: idxs.data = torch.load(_tmp_file)\n",
    "    else:                 torch.save(idxs,_tmp_file)\n",
    "    return idxs.tolist()\n",
    "# Mocked `_to_detach`: the result has the gathered size, but slots of other ranks are filled with -100\n",
    "@patch\n",
    "def _to_detach(self:DistributedDL,b,cpu=True,gather=True):\n",
    "    b = to_detach(b,cpu,gather)\n",
    "    if not gather: return b\n",
    "    def _fake_gather(t, cpu, gather):\n",
    "        if t.ndim == 0: t=t[None]\n",
    "        parts = []\n",
    "        for r in range(self.world_size):\n",
    "            # keep the real values in this rank's slot, a -100 placeholder everywhere else\n",
    "            parts.append(t if r==self.rank else torch.full_like(t,-100))\n",
    "        t = torch.cat(parts)\n",
    "        return t if t.ndim > 0 else t.mean()\n",
    "    return apply(_fake_gather,b,cpu,gather)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 50 items padded to 52 for 4 ranks: 13 per rank, i.e. one batch of 12 followed by one batch of 1\n",
    "dl = TfmdDL(list(range(50)), bs=12, num_workers=2)\n",
    "for rank in range(4):\n",
    "    start = rank*13 # first index of this rank's slice; values wrap modulo 50 because of the padding\n",
    "    expected = (torch.arange(start, start+12)%50,torch.tensor([start+12])%50)\n",
    "    rank_dl = DistributedDL(dl, rank, 4)\n",
    "    test_eq(list(rank_dl), expected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# same slicing as for plain items, but every item is an (x, x+100) pair\n",
    "dl = TfmdDL(list(zip(range(50),range(100,150))), bs=12, num_workers=4)\n",
    "for rank in range(4):\n",
    "    start = rank*13\n",
    "    full_x = torch.arange(start, start+12)%50\n",
    "    last_x = torch.tensor([start+12])%50\n",
    "    rank_dl = DistributedDL(dl, rank, 4)\n",
    "    test_eq(list(rank_dl), [(full_x,100+full_x),\n",
    "                            (last_x,100+last_x)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# with `drop_last=True` only the full batch of 12 is expected on each rank\n",
    "dl = TfmdDL(list(range(50)), bs=12, num_workers=2,drop_last=True)\n",
    "for rank in range(4):\n",
    "    start = rank*13\n",
    "    rank_dl = DistributedDL(dl, rank, 4)\n",
    "    test_eq(list(rank_dl), [torch.arange(start, start+12)%50])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 12 items padded to 15 for 5 ranks (3 per rank): the slice of rank 4 consists of padding only\n",
    "dl = TfmdDL(list(zip(range(12),range(100,112))), bs=12, num_workers=4)\n",
    "dls = [DistributedDL(dl, r, 5) for r in range(5)]\n",
    "for b in zip(*dls):\n",
    "    for r in range(5):\n",
    "        d=L(dls[r].to_detach(b[r]))\n",
    "        # the mocked gather puts -100 in other ranks' slots, so this counts the rank's own entries left after unpadding\n",
    "        test_eq(d.map(lambda x:(x!=-100).sum().item()),(3,3) if r!=4 else (0,0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 10 shuffled items padded to 12 for 3 ranks: 4 per rank, and the last 2 items of rank 2 are padding\n",
    "dl = TfmdDL(list(range(10)), bs=4, num_workers=2, shuffle=True)\n",
    "for i in range(3):\n",
    "    dl1 = DistributedDL(dl, i, 3)\n",
    "    b  = list(dl1)[0]\n",
    "    bd = dl1.to_detach(b)\n",
    "    # after unpadding, the gathered tensor holds this rank's real items at positions [4*i, 4*(i+1))\n",
    "    test_eq(b[:None if i<2 else 2],bd[4*i:4*(i+1)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.callback.data import WeightedDL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# items below 25 get a weight of False (zero), so the test expects every drawn index to be >= 25\n",
    "dl = WeightedDL(list(range(50)), bs=16, num_workers=2, shuffle=True,wgts=list(np.arange(50)>=25))\n",
    "res = []\n",
    "for i in range(4):\n",
    "    dl1 = DistributedDL(dl, i, 4)\n",
    "    res += list(dl1)[0].tolist()\n",
    "# Compare element by element: `operator.ge` on two lists is a lexicographic comparison,\n",
    "# and `~` applied to a Python bool yields -1 or -2, which is always truthy.\n",
    "def all_ge(a,b): return all(x>=y for x,y in zip(a,b))\n",
    "assert len(res) > 0, 'no samples were drawn'\n",
    "test(res,[25]*len(res),all_ge)  # every element of res is >= 25\n",
    "test(min(res),25,operator.ge)   # equivalently, the smallest element is NOT < 25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class DistributedTrainer(Callback):\n",
    "    \"Callback that wraps the model in `DistributedDataParallel` and the data loaders in `DistributedDL` during fit.\"\n",
    "    run_after  = TrainEvalCallback\n",
    "    run_before = Recorder\n",
    "    fup = None # for `find_unused_parameters` in DistributedDataParallel()\n",
    "\n",
    "    def __init__(self, cuda_id=0,sync_bn=True): store_attr('cuda_id,sync_bn')\n",
    "\n",
    "    def before_fit(self):\n",
    "        \"Swap in the DDP-wrapped model and the wrapped loaders; replace the logger with `noop` when `rank_distrib()` is positive.\"\n",
    "        # only pass `find_unused_parameters` through when the class attribute has been set\n",
    "        fup = DistributedTrainer.fup\n",
    "        ddp_kwargs = {} if fup is None else { 'find_unused_parameters' : fup }\n",
    "        model = self.model\n",
    "        if self.sync_bn: model = nn.SyncBatchNorm.convert_sync_batchnorm(model)\n",
    "        self.learn.model = DistributedDataParallel(\n",
    "            model, device_ids=[self.cuda_id], output_device=self.cuda_id, **ddp_kwargs)\n",
    "        # remember the original loaders so `after_fit` can put them back\n",
    "        self.old_dls = list(self.dls)\n",
    "        self.learn.dls.loaders = [self._wrap_dl(dl) for dl in self.dls]\n",
    "        if rank_distrib() > 0: self.learn.logger=noop\n",
    "\n",
    "    def _wrap_dl(self, dl):\n",
    "        \"Return `dl` unchanged if it is already a `DistributedDL`, otherwise wrap it.\"\n",
    "        if isinstance(dl, DistributedDL): return dl\n",
    "        return DistributedDL(dl, rank_distrib(), num_distrib())\n",
    "\n",
    "    def before_train(self):\n",
    "        self.learn.dl = self._wrap_dl(self.learn.dl)\n",
    "\n",
    "    def before_validate(self):\n",
    "        self.learn.dl = self._wrap_dl(self.learn.dl)\n",
    "\n",
    "    def after_fit(self):\n",
    "        \"Restore the unwrapped module and the original loaders.\"\n",
    "        learn = self.learn\n",
    "        learn.model = learn.model.module\n",
    "        learn.dls.loaders = self.old_dls"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Attach or remove the callback that adapts the learner to use `DistributedDL` and train in distributed data parallel mode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "def to_distributed(self: Learner, cuda_id,sync_bn=True):\n",
    "    \"Attach a `DistributedTrainer`; when `rank_distrib()` is positive also remove `ProgressCallback`. Returns the learner.\"\n",
    "    trainer = DistributedTrainer(cuda_id,sync_bn)\n",
    "    self.add_cb(trainer)\n",
    "    if rank_distrib() > 0:\n",
    "        self.remove_cb(ProgressCallback)\n",
    "    return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "def detach_distributed(self: Learner):\n",
    "    \"Remove `DistributedTrainer` (nothing is done when `num_distrib() <= 1`) and return the learner.\"\n",
    "    if num_distrib() > 1:\n",
    "        self.remove_cb(DistributedTrainer)\n",
    "        # give a positive rank its progress callback back if the learner has no `progress` attribute\n",
    "        needs_progress = rank_distrib() > 0 and not hasattr(self, 'progress')\n",
    "        if needs_progress: self.add_cb(ProgressCallback())\n",
    "    return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "@contextmanager\n",
    "def distrib_ctx(self: Learner, cuda_id=None,sync_bn=True):\n",
    "    \"Context manager that sets the learner up for distributed data parallel training and undoes it on exit.\"\n",
    "    # Without an explicit GPU, use the value of `rank_distrib()`.\n",
    "    if cuda_id is None: cuda_id = rank_distrib()\n",
    "    # Only a process group that this call initialised itself is torn down on exit.\n",
    "    cleanup_dpg = False\n",
    "    if not torch.distributed.is_initialized():\n",
    "        setup_distrib(cuda_id)\n",
    "        cleanup_dpg = torch.distributed.is_initialized()\n",
    "    try:\n",
    "        if num_distrib() > 1: self.to_distributed(cuda_id,sync_bn)\n",
    "        yield self\n",
    "    finally:\n",
    "        # runs whether the body completed or raised\n",
    "        self.detach_distributed()\n",
    "        if cleanup_dpg: teardown_distrib()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `distrib_ctx` context manager"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**`distrib_ctx(cuda_id)`** prepares a learner to train in distributed data parallel mode.  It assumes these [environment variables](https://pytorch.org/tutorials/intermediate/dist_tuto.html#initialization-methods) have all been set up properly, as they are for processes launched by [`python -m fastai.launch`](https://github.com/fastai/fastai/blob/master/fastai/launch.py).\n",
    "\n",
    "Typical usage:\n",
    "\n",
    "```\n",
    "with learn.distrib_ctx(): learn.fit(.....)\n",
    "```\n",
    "\n",
    "It attaches a `DistributedTrainer` callback and a `DistributedDL` data loader to the learner, then executes `learn.fit(.....)`.  Upon exiting the context, it removes the `DistributedTrainer` and `DistributedDL`, and destroys any locally created distributed process group.  The process is still attached to the GPU though."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def rank0_first(func):\n",
    "    \"Call `func` in the rank-0 process first and, after `distrib_barrier()`, in every other rank; return its result.\"\n",
    "    # a throwaway learner, created only to get access to `distrib_ctx`\n",
    "    dummy_l = Learner(DataLoaders(device='cpu'), nn.Linear(1,1), loss_func=lambda: 0)\n",
    "    with dummy_l.distrib_ctx():\n",
    "        if not rank_distrib(): res = func()\n",
    "        distrib_barrier()\n",
    "        if rank_distrib(): res = func()\n",
    "    return res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**`rank0_first(f)`** calls `f()` in rank-0 process first, then in parallel on the rest, in distributed training mode. In single process, non-distributed training mode, `f()` is called only once as expected.\n",
    "\n",
    "One application of `rank0_first()` is to make fresh downloads via `untar_data()` safe in distributed training scripts launched by `python -m fastai.launch <script>`:\n",
    "\n",
    "> <code>path = untar_data(URLs.IMDB)</code>\n",
    "\n",
    "becomes:\n",
    "\n",
    "> <code>path = <b>rank0_first(lambda:</b> untar_data(URLs.IMDB))</code>\n",
    "\n",
    "\n",
    "Some learner factory methods may use `untar_data()` to **download pretrained models** by default:\n",
    "\n",
    "> <code>learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)</code>\n",
    "\n",
    "becomes:\n",
    "\n",
    "> <code>learn = <b>rank0_first(lambda:</b> text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy))</code>\n",
    "\n",
    "Otherwise, multiple processes will download at the same time and corrupt the data.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "# Export the notebooks to script modules (this produces the `Converted ...` lines in the output)\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
